{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 千帆 function call 入门\n",
    "\n",
    "## 简介\n",
    "\n",
    "function_call，顾名思义，通过给大模型提供 function 的说明描述，以及对应的入参出参 schema，让大模型输出 function 调用策略，结合多轮对话，以最终实现一个复杂的任务。\n",
    "以下将以获取数据库中某类文件的数量为例子，通过调用千帆 Python SDK提供的 ERNIE-Bot 大模型以得到数据库中该语言的文件数量。\n",
    "\n",
    "## 准备\n",
    "\n",
    "本文使用了千帆 Python SDK中的 chat_completion 模块，该模块提供了与千帆对话引擎的交互接口，目前支持function call的模型有ERNIE-Bot与ERNIE-Bot-4。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a2ef73d1b5366c1d"
  },
  {
   "cell_type": "markdown",
   "source": [
    "首先安装千帆 Python SDK，版本号 >= 0.2.6。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f37417f96d78b87e"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install \"qianfan>=0.2.6\" -U"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:40.175335Z",
     "start_time": "2024-01-15T05:06:37.775010Z"
    }
   },
   "id": "5779c8b36b65d730",
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "source": [
    "初始化我们所需要的凭证"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9e26bd0785ed4ddb"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# 初始化LLM\n",
    "import os\n",
    "\n",
    "# qianfan sdk 鉴权\n",
    "# os.environ[\"QIANFAN_ACCESS_KEY\"]=\"...\"\n",
    "# os.environ[\"QIANFAN_SECRET_KEY\"]=\"...\"\n"
   ],
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:40.179371Z",
     "start_time": "2024-01-15T05:06:40.177163Z"
    }
   },
   "id": "initial_id",
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "source": [
    "自定义一个给chat调用的函数，此处以获取数据库中特定语言撰写的文件数量为例"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "164bd2b4bea750b3"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "def get_file_num(language: str) -> str:\n",
    "    \"\"\"获取数据库中指定语言的代码文件数量\"\"\"\n",
    "    language_low = language.lower()\n",
    "    language_map = {\n",
    "        \"c/c++\": 35,\n",
    "        \"java\": 10,\n",
    "        \"javascript\": 25,\n",
    "        \"python\": 35,\n",
    "        \"go\": 32,\n",
    "    }\n",
    "    return str(language_map.get(language_low, 0))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:40.185363Z",
     "start_time": "2024-01-15T05:06:40.181918Z"
    }
   },
   "id": "9af388291ce5434b",
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "source": [
    "描述函数调用策略\n",
    "其中，name为函数名称，description为函数描述，properties为入参schema，required为必填参数\n",
    "properties支持json schema格式，具体参考[json schema](https://json-schema.org/learn/getting-started-step-by-step.html\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "598094c5ec2b987c"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "func_list = [{\n",
    "    \"name\": \"get_file_num\",  # 函数名称\n",
    "    \"description\": \"获取内部数据库中以某一编程语言编写的文件数量\",  # 函数描述\n",
    "    \"parameters\":{\n",
    "        \"type\":\"object\",\n",
    "        \"properties\":{  # 参数schema，如果参数为空，设为空字典即可\n",
    "            \"language\":{  # 参数名称\n",
    "                \"type\":\"string\",  # 参数类型\n",
    "                \"description\": \"代码所运用的编程语言，例如：python、c/c++、go、java\"  # 参数描述\n",
    "            }\n",
    "        },\n",
    "        \"required\":[\"language\"]  # 必填参数（无默认值）\n",
    "    }\n",
    "}]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:40.189650Z",
     "start_time": "2024-01-15T05:06:40.185262Z"
    }
   },
   "id": "e883e0b92ccecade",
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "source": [
    "创建 chat_completion 对象,调用 do 方法进行交互"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "42435501fc3cf5af"
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": []
    }
   ],
   "source": [
    "import qianfan\n",
    "import json\n",
    "\n",
    "chat_comp = qianfan.ChatCompletion(model=\"ERNIE-Bot\")  # 指定模型，目前ERNIE-Bot和ERNIE-Bot4.0支持function call\n",
    "query = \"请帮我查询一下数据库中用go撰写的代码文件数量\"\n",
    "msgs = qianfan.QfMessages()\n",
    "msgs.append(query,role='user')\n",
    "resp = chat_comp.do(\n",
    "    messages=msgs,\n",
    "    functions=func_list\n",
    ")\n",
    "print(resp['body']['result'])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:42.657807Z",
     "start_time": "2024-01-15T05:06:40.194225Z"
    }
   },
   "id": "ff5886e76c5a572f",
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "source": [
    "可以发现，此时chat的反馈为null，我们尝试打印返回值"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "91b0bcba42f3debe"
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "QfResponse(code=200, headers={'Access-Control-Allow-Headers': 'Content-Type', 'Access-Control-Allow-Origin': '*', 'Appid': '26217442', 'Connection': 'keep-alive', 'Content-Encoding': 'gzip', 'Content-Type': 'application/json; charset=utf-8', 'Date': 'Mon, 15 Jan 2024 05:06:42 GMT', 'P3p': 'CP=\" OTI DSP COR IVA OUR IND COM \"', 'Server': 'Apache', 'Set-Cookie': 'BAIDUID=24348C0A48AF1D67E33939F0A17EE737:FG=1; expires=Thu, 31-Dec-37 23:55:55 GMT; max-age=2145916555; path=/; domain=.baidu.com; version=1', 'Statement': 'AI-generated', 'Vary': 'Accept-Encoding', 'X-Aipe-Self-Def': 'eb_total_tokens:199,prompt_tokens:164,id:as-wpux8h3f0a', 'X-Baidu-Request-Id': 'sdk-py-0.2.7-evQ27Nto7272jjyB', 'X-Openapi-Server-Timestamp': '1705295200', 'Content-Length': '363'}, body={'id': 'as-wpux8h3f0a', 'object': 'chat.completion', 'created': 1705295202, 'result': '', 'is_truncated': False, 'need_clear_history': False, 'function_call': {'name': 'get_file_num', 'thoughts': '用户想要查询数据库中用go撰写的代码文件数量，我可以使用get_file_num工具来完成这个任务。', 'arguments': '{\"language\":\"go\"}'}, 'finish_reason': 'function_call', 'usage': {'prompt_tokens': 164, 'completion_tokens': 35, 'total_tokens': 199}}, statistic={'request_latency': 2.442524, 'total_latency': 2.4508387089999815})\n"
     ]
    }
   ],
   "source": [
    "print(resp)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:42.658288Z",
     "start_time": "2024-01-15T05:06:42.650216Z"
    }
   },
   "id": "9672df59be0cd560",
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "source": [
    "从thoughts可见chat确定需要调用函数get_file_num，我们需要将反馈作为function输入来进行二次对话"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "20eb94c22810a6ea"
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "根据我们的查询，数据库中用Go撰写的代码文件数量为32个。如果您还有其他问题或需要进一步的信息，请随时告诉我。\n"
     ]
    }
   ],
   "source": [
    "if  resp.get(\"function_call\"):\n",
    "    # 获取函数名称、入参及返回值\n",
    "    func_call_result = resp[\"function_call\"]\n",
    "    func_name = func_call_result[\"name\"]\n",
    "    language = json.loads(func_call_result[\"arguments\"]).get(\"language\")\n",
    "    func_resp = get_file_num(language)\n",
    "    \n",
    "    # 将函数返回值转换成json字符串\n",
    "    func_content = json.dumps({\n",
    "        \"return\":func_resp\n",
    "    })\n",
    "    \n",
    "    # 创建新的消息\n",
    "    msgs.append(resp, role=\"assistant\")\n",
    "    msgs.append(func_content, role=\"function\")\n",
    "    \n",
    "    # 再次调用chat_completion\n",
    "    second_resp = chat_comp.do(\n",
    "        messages=msgs,\n",
    "        functions=func_list\n",
    "    )\n",
    "    \n",
    "    print(second_resp['body']['result'])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-15T05:06:46.044416Z",
     "start_time": "2024-01-15T05:06:42.655427Z"
    }
   },
   "id": "80a90a0b3061bb21",
   "execution_count": 14
  },
  {
   "cell_type": "markdown",
   "source": [
    "通过以上步骤，我们成功让chat调用了函数get_file_num，并得到了正确的返回值。\n",
    "接下来，我们可以尝试让chat调用千帆自带的工具函数\n",
    "[千帆function_call工具调用](./function_call_with_tool.ipynb)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "55d5d22bac9c54e0"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
